import torch
import torch.nn as nn
import pickle
import random
import numpy as np
from torch.utils.data import DataLoader
from predictnext3 import Predictor, num_references, num_features, max_spacing
from music21 import *
import pitchprominence
from inspect import getmembers, isfunction

# Load the trained sequence predictor and the pickled list of feature names.
model = Predictor()
model.load_state_dict(torch.load("graphnn/predictprogram.pth"))
# NOTE(review): model.eval() is never called; if Predictor contains dropout or
# batch-norm layers, sampling below runs them in training mode — confirm intent.
with open("../symmfuncs.pcl", "rb") as feat_file:
	feat_names = pickle.load(feat_file)


#midiprobs = pickle.load(open("../midiprobs.pcl", "rb"))
# Number of sequences to accept in the sampling loop (and MIDI files read back).
n_tot = 1000000
def transposeToC(xss):
	"""Return the transposition of *xss* that scores highest on a pitch-class table.

	*xss* is a list of measures, each a list of note tuples; only the first
	two fields of each tuple (pitch, duration) are read. Every shift from
	-6 to +6 semitones is tried. A candidate's score is the sum over all
	notes of ``weight[pitch % 12] * duration``. The table peaks at pitch
	class 0 (C, assuming MIDI pitch numbers), then 7, then 5; 1 and 6 are
	the most negative.

	Returns a new list of measures of ``(pitch + shift, duration)`` tuples.
	Ties go to the earliest (most negative) shift.
	"""
	weight_by_pc = {
		0: 7, 7: 5, 5: 4,
		4: 2, 2: 2, 11: 2, 9: 2,
		3: 0, 8: 0,
		10: -1,
		1: -5, 6: -5,
	}

	def shifted(shift):
		return [[(item[0] + shift, item[1]) for item in measure] for measure in xss]

	def tonal_score(candidate):
		# Sum within each measure first, then across measures.
		return sum(
			sum(weight_by_pc[item[0] % 12] * item[1] for item in measure)
			for measure in candidate
		)

	return max((shifted(shift) for shift in range(-6, 7)), key=tonal_score)

# (name, function) pairs for every member of pitchprominence that
# inspect.isfunction accepts, sorted by name (as inspect.getmembers returns them).
functions_list = getmembers(pitchprominence, isfunction)

def getProfile(measure, functions=None):
	"""Build a 12-bin pitch-class profile for *measure*.

	Each ``(name, func)`` pair in *functions* is called as
	``func(measure, pc)`` for every pitch class ``pc`` in 0..11. *functions*
	defaults to the module-level ``functions_list``. Each call must return a
	``(has_feat, weight)`` pair, and ``weight`` is added to bin ``pc`` whenever
	``has_feat`` is truthy.

	This is best effort: if any function raises an ``Exception`` for a pitch
	class, the measure is printed and the remaining functions for that pitch
	class are skipped. Weights already added for it are kept.

	Returns a numpy array of 12 floats.
	"""
	if functions is None:
		functions = functions_list
	measure_profile = np.zeros(12)
	for pc in range(12):
		try:
			for _name, func in functions:
				has_feat, weight = func(measure, pc)
				if has_feat:
					measure_profile[pc] += weight
		except Exception:
			# Narrowed from a bare ``except:`` so Ctrl-C / SystemExit still propagate.
			print(measure)
	return measure_profile

# ---------------------------------------------------------------------------
# Sample reference / spacing / feature sequences from the predictor.
# Each outer iteration seeds a 3-step history and rolls the model forward until
# at least 16 references exist. The result is kept (trimmed to 16) only if it
# passes the acceptance filter at the bottom of the loop.
# ---------------------------------------------------------------------------
all_spacings = []
all_refs = []
all_feats = []
ind = 0
ref_mean = []  # number of active feature flags per model step (diagnostic print only)
aheads = []  # sampled "ahead" values (diagnostic only)
while ind < n_tot:
	prev_refs = [random.randint(0,4) for i in range(3)]
	prev_feats = [{k:True for k in feat_names} for j in range(3)] #initialize features
	#initialize spacings
	prev_spacings = [0,0,0]
	# Model input = the 6-row [one-hot refs | one-hot spacings] history
	# flattened row by row, followed by the last two feature vectors.
	# NOTE(review): the seed one-hot rows below are drawn independently of
	# prev_refs / prev_spacings, so the model's seed history does not match the
	# recorded seed values — confirm this is intentional.
	prev_vec1 = np.zeros((6, num_references))
	prev_vec2 = np.zeros((6, max_spacing + 1))
	prev_vec1[4,random.randint(0,4)] = 1
	prev_vec1[5,random.randint(0,4)] = 1
	prev_vec1[3,random.randint(0,4)] = 1
	prev_vec2[3,random.randint(0,1)] = 1
	prev_vec2[4,random.randint(0,1)] = 1
	prev_vec2[5,random.randint(0,1)] = 1
	prev_vec3 = np.ones(2*num_features)
	prev_vec = np.concatenate(np.concatenate([prev_vec1, prev_vec2], axis=1))
	prev_vec = np.concatenate([prev_vec, prev_vec3], axis=0)
	while len(prev_refs) < 16: #generate at least 16 references; trimmed to 16 on acceptance
		with torch.no_grad(): #inference only; no autograd graph needed
			spacing, refs, ahead = model(torch.from_numpy(prev_vec).float())
		# Add Gaussian noise to the model outputs before taking argmaxes /
		# thresholds, so sampling is not deterministic.
		ahead = ahead + 0.4*torch.randn(ahead.shape)
		spacing = spacing + 0.01*torch.randn(spacing.shape)
		spacing = int(torch.argmax(spacing))

		ahead = int(torch.argmax(ahead))
		aheads.append(ahead)
		refs = refs + 0.1*torch.randn(refs.shape)
		# 30% of the time switch every feature flag on; otherwise threshold at 0.73.
		if random.uniform(0,1) < 0.3:
			refs = np.ones(refs[0].shape)
		else:
			refs = (refs > 0.73)[0].numpy()
		ref_mean.append(sum(refs))

		for q_ in range(ahead + 1): #emit ahead + 1 steps with this spacing / feature prediction
			# Fresh buffers for every emitted step. They used to be allocated once
			# outside this loop, so from the second step on new_vec1/2 aliased
			# prev_vec1/2. The shift then read rows it had already overwritten,
			# and the last row accumulated several one-hot bits.
			new_vec1 = np.zeros((6, num_references))
			new_vec2 = np.zeros((6, max_spacing + 1))
			# NOTE(review): new row i takes old row i - 1, so new row 0 receives the
			# wrapped-around old row 5 and old row 4 is dropped. If the rows are a
			# sliding window with the newest last, this should probably read
			# prev_vec[i + 1] — confirm against how predictnext3 encodes inputs.
			for i in range(5):
				new_vec1[i,:] = prev_vec1[i - 1,:]
				new_vec2[i,:] = prev_vec2[i - 1,:]
			# Choose the reference index:
			#  * copy the ref `spacing` positions from the end when spacing is usable
			#    and the piece does not currently hold 9-11 references;
			#  * otherwise, if some index 0-4 is still unused, with probability 0.5
			#    take the smallest unused index;
			#  * otherwise pick uniformly from 0-4.
			ref = prev_refs[-1*spacing] if (spacing != 0 and spacing < len(prev_refs) and (len(prev_refs) < 9 or (len(prev_refs[:-9]) >= 3 and len(prev_refs[:-4]) > 1))) else max([k for k in range(5) if all([k_ in prev_refs for k_ in range(k)])]) if (len(set(prev_refs)) < 5 and random.uniform(0,1) < 0.5) else random.randint(0,4)
			new_vec2[-1, spacing] = 1
			new_vec1[-1, ref] = 1
			prev_refs.append(ref)
			prev_spacings.append(spacing)

			# Feature history: previous "current" features move to the first half.
			new_vec3 = np.zeros(2*num_features)
			new_vec3[:num_features] = prev_vec3[num_features:]
			new_vec3[num_features:] = refs

			prev_feats.append({k:refs[feat_names.index(k)] > 0.5 for k in feat_names}) #get symmetries

			prev_vec1 = new_vec1
			prev_vec2 = new_vec2 #was misspelled `pref_vec2`, so the spacing history never advanced
			prev_vec3 = new_vec3
			prev_vec = np.concatenate(np.concatenate([prev_vec1, prev_vec2], axis=1))
			prev_vec = np.concatenate([prev_vec, prev_vec3], axis=0)
	# Accept the piece only if:
	#  * all five reference indices occur;
	#  * no window of six consecutive references is constant (windows start at
	#    0 .. len - 6; the old range(len - 6) skipped the last one);
	#  * the count test below passes.
	# NOTE(review): range(6) also counts index 5, which never occurs (refs are
	# always 0-4), so that term always contributes 1 and "< 2" reduces to "every
	# index 0-4 occurs at least twice". If range(5) was meant, this is stricter
	# than designed. All checks run on the untrimmed prev_refs, which can hold
	# more than the 16 entries kept.
	if len(set(prev_refs)) == 5 and not any(len(set(prev_refs[i:i + 6])) == 1 for i in range(len(prev_refs) - 5)) and len([i for i in range(6) if prev_refs.count(i) < 2]) < 2:
		all_spacings.append(prev_spacings[:16])
		print(prev_refs[:16])
		all_feats.append(prev_feats[:16])
		all_refs.append(prev_refs[:16])
		ind += 1
		print((ind, np.mean(ref_mean)))
	
		
# ---------------------------------------------------------------------------
# Get the actual measures the generated references point to.
# Each reference MIDI file's first top-level element is split into measures of
# (midi pitch, quarterLength) pairs. The measures selected by the matching
# generated reference sequence are then collected.
# ---------------------------------------------------------------------------
all_notes = []  # not filled in below
all_ref_measures = []
all_prev_profiles = []  # not filled in below

for z, piece_refs in enumerate(all_refs):
	try:
		# Parse once; the file used to be parsed six times with only the first result used.
		mid = list(converter.parse("../referencemids/" + str(z) + ".mid"))[0]
	except Exception:
		print("mid error")
		# Placeholder keeps all_ref_measures index-aligned with all_refs / all_feats.
		# Previously the loop fell through with the previous file's `mid` (or a
		# NameError on the first file).
		all_ref_measures.append(None)
		continue
	notes = [[]]
	prev_note = 60  # pitch used when a measure opens with a rest before any note
	for val in list(mid):
		if type(val) == note.Note:
			notes[-1].append((val.pitch.midi, val.quarterLength))
			prev_note = val.pitch.midi
		elif type(val) == note.Rest:
			# Rests are absorbed: they lengthen the measure's last note, or, at the
			# start of a measure, become a note on the last pitch seen.
			if len(notes[-1]) > 0:
				notes[-1][-1] = (notes[-1][-1][0], notes[-1][-1][1] + val.quarterLength)
			else:
				notes[-1].append((prev_note, val.quarterLength))

		# A measure is closed once its durations total exactly 4.0 quarter lengths.
		if sum([i[1] for i in notes[-1]]) > 4.0:
			print("bad in")
		if sum([i[1] for i in notes[-1]]) == 4.0:
			if all([k[1] == 0.25 for k in notes[-1]]):
				# Sixteen 0.25 notes: merge one pair starting on a quarter-note
				# boundary into a single 0.5 note that keeps the first pitch.
				to_del = random.choice([0,4,8,12])
				notes[-1][to_del] = (notes[-1][to_del][0], 0.5)
				del notes[-1][to_del + 1]
			notes.append([])
	# Measures selected by this piece's references, in sequence order. This used
	# to be built but never stored, so allrefmeasures.pcl was always empty.
	all_ref_measures.append([notes[j] for j in piece_refs])


with open("pickle/allrefmeasures.pcl", "wb") as out_file:
	pickle.dump(all_ref_measures, out_file)
with open("pickle/reference_features.pcl", "wb") as out_file:
	pickle.dump(all_feats, out_file)
with open("pickle/allrefs.pcl", "wb") as out_file:
	pickle.dump(all_refs, out_file)
